# coding=utf-8
import os
import sys
sys.path.append("..")
sys.path.append("../utils")
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from torch.utils.data import Dataset, DataLoader
import config.yolov3_config_voc as cfg
import cv2
import numpy as np
import random
import utils.data_augment as dataAug
import utils.tools as tools
import torch.nn.functional as F
import torchvision.transforms as transforms

# 根据txt文件制作ground truth
CLASS_NUM = 20  # 使用其他训练集需要更改


class PLNDataset(Dataset):
    """
    点链接网络(Point Linking Network)数据集类
    
    用于加载和处理目标检测数据，生成PLN模型所需的标签格式
    """
    
    def __init__(self, anno_file_type, img_size=448):
        """
        初始化数据集
        
        参数:
            anno_file_type (str): 'train' or 'test'
            img_size (int): 图像大小，默认448
        """
        self.img_size = img_size
        self.classes = cfg.DATA["CLASSES"]
        self.num_classes = len(self.classes)
        self.class_to_id = dict(zip(self.classes, range(self.num_classes)))
        
        # 存储数据的列表
        self.fnames = []    # 文件名
        self.boxes = []     # 边界框坐标 [x1, y1, x2, y2]
        self.labels = []    # 类别标签
        
        # 网络相关参数
        self.S = 14         # 网格大小
        self.B = 2          # 每个位置预测的边界框数量
        self.C = CLASS_NUM  # 类别数量
        
        self.mean = (0, 0, 0)
        
        # 加载标注
        self.__annotations = self.__load_annotations(anno_file_type)
        self.__parse_all_annotations()
        
        # 样本总数
        self.num_samples = len(self.boxes)

    def __len__(self):
        return len(self.__annotations)

    def __getitem__(self, item):
        # 获取图像文件名和对应的边界框、标签
        fname = self.fnames[item]
        boxes = self.boxes[item].clone() if len(self.boxes[item]) > 0 else torch.zeros((0, 4))
        labels = self.labels[item].clone() if len(self.labels[item]) > 0 else torch.zeros((0,), dtype=torch.long)

        # 转换路径格式（从Linux绝对路径转换为相对路径）
        if fname.startswith('/home/haotiantan/home/tht/YOLOV3/'):
            # 提取相对路径部分
            relative_path = fname.replace('/home/haotiantan/home/tht/YOLOV3/', '')
            fname = relative_path.replace('/', os.sep)  # 使用当前系统的路径分隔符
        
        # 读取图像
        img = cv2.imread(fname)
        if img is None:
            # 如果还是找不到，尝试其他可能的路径
            alt_paths = [
                fname.replace('data/VOCdevkit', 'data/VOCdevkit'),
                fname.replace('VOCdevkit', 'data/VOCdevkit'),
                os.path.join('data', 'VOCdevkit', os.path.basename(fname))
            ]
            for alt_path in alt_paths:
                img = cv2.imread(alt_path)
                if img is not None:
                    fname = alt_path
                    break
        
        assert img is not None, f'File Not Found: {fname}'
        
        # 获取图像尺寸
        h, w, _ = img.shape
        
        # 如果有边界框，进行归一化
        if len(boxes) > 0:
            boxes = boxes / torch.Tensor([w, h, w, h]).expand_as(boxes)  # 归一化到[0,1]

        img = cv2.resize(img, (self.img_size, self.img_size))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # 创建PLN格式的标签
        label_generator = LabelGenerator("", "", "", 0, self.img_size, 14)
        target = []
        for branch in range(4):  # 四个分支分别对应四个角点
            t = label_generator.generate_label(branch, boxes, labels)
            target.append(t)

        # 堆叠成[4, 14, 14, 204]的张量
        target = torch.stack(target)

        # 应用数据增强
        if cfg.TRAIN["AUGMENT"]:
            img = self._apply_augmentations(img)

        # 应用图像变换
        img = torch.from_numpy(img).float().permute(2, 0, 1) / 255.0  # HWC->CHW, 归一化

        return img, target

    def __load_annotations(self, anno_type):
        assert anno_type in ['train', 'test'], "You must choice one of the 'train' or 'test' for anno_type parameter"
        anno_path = os.path.join(cfg.PROJECT_PATH, 'data', anno_type + ".txt")
        with open(anno_path, 'r') as f:
            annotations = list(filter(lambda x:len(x)>0, f.readlines()))
        assert len(annotations)>0, "No images found in {}".format(anno_path)
        return annotations

    def __parse_all_annotations(self):
        """解析所有标注文件"""
        for annotation in self.__annotations:
            anno = annotation.strip().split(' ')
            img_path = anno[0]
            self.fnames.append(img_path)
            
            if len(anno) > 1:
                # 临时存储当前图像的边界框和标签
                boxes = []
                labels = []
                
                # 提取每个边界框的坐标和类别 (格式: xmin,ymin,xmax,ymax,class)
                for i in range(1, len(anno)):
                    try:
                        # 解析逗号分隔的坐标和类别
                        bbox_info = anno[i].split(',')
                        if len(bbox_info) == 5:  # 确保有5个值: x1,y1,x2,y2,class
                            x1 = float(bbox_info[0])  # xmin
                            y1 = float(bbox_info[1])  # ymin
                            x2 = float(bbox_info[2])  # xmax
                            y2 = float(bbox_info[3])  # ymax
                            c = int(bbox_info[4])     # 类别 (0-19)
                            
                            boxes.append([x1, y1, x2, y2])
                            labels.append(c)
                    except (ValueError, IndexError) as e:
                        print(f"Warning: Skipping invalid bbox annotation: {anno[i]} in {img_path}")
                        continue
                
                # 转换为Tensor并添加到数据列表
                if len(boxes) > 0:
                    self.boxes.append(torch.Tensor(boxes))
                    self.labels.append(torch.LongTensor(labels))
                else:
                    self.boxes.append(torch.zeros((0, 4)))
                    self.labels.append(torch.zeros((0,), dtype=torch.long))
            else:
                # 没有边界框的情况
                self.boxes.append(torch.zeros((0, 4)))
                self.labels.append(torch.zeros((0,), dtype=torch.long))

    def BGR2RGB(self, img):
        """BGR格式转RGB格式"""
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def subtract_mean(self, img, mean=None):
        """
        图像减去均值
        
        参数:
            img: 输入图像
            mean: 均值，默认使用类初始化时设定的均值
        """
        if mean is None:
            mean = self.mean
        mean = np.array(mean, dtype=np.float32)
        return img - mean

    def _apply_augmentations(self, img):
        """
        应用数据增强策略
        """
        import utils.data_augment as dataAug
        
        # 创建虚拟边界框用于数据增强（PLN不需要真实的bbox）
        dummy_bboxes = np.array([[0, 0, 1, 1]])  # 虚拟边界框
        
        # 应用各种数据增强
        img, _ = dataAug.RandomBrightness(p=0.5, delta=0.15)(img, dummy_bboxes)
        img, _ = dataAug.RandomContrast(p=0.5, lower=0.8, upper=1.2)(img, dummy_bboxes)
        img, _ = dataAug.RandomSaturation(p=0.3, lower=0.8, upper=1.2)(img, dummy_bboxes)
        img, _ = dataAug.RandomNoise(p=0.2, noise_std=0.01)(img, dummy_bboxes)
        
        return img


class LabelGenerator:
    """
    PLN标签生成器
    
    负责将边界框和类别标签转换为PLN格式的标签
    """
    
    def __init__(self, train_dir_obj, test_dir_obj, loader_type, seed, pic_width, S=14):
        """
        初始化标签生成器
        
        参数:
            train_dir_obj: 训练目录
            test_dir_obj: 测试目录
            loader_type: 加载类型
            seed: 随机种子
            pic_width: 图像宽度
            S: 网格大小，默认为14
        """
        self.loader_type = loader_type
        self.s = S                   # 网格大小
        self.classes = 20            # 类别数量
        self.B = 2                   # 每个位置的边界框数量
        self.infinite = 0.01         # 用于类别标签的合理数值，进一步降低
        self.pic_width = pic_width   # 图像宽度
        
        # 设置随机种子
        torch.manual_seed(seed)

    def generate_label(self, branch, boxes, labels):
        """
        生成指定分支的PLN标签
        
        参数:
            branch (int): 分支索引 (0-3)，对应左下、左上、右下、右上四个角点
            boxes (Tensor): 归一化后的边界框坐标
            labels (Tensor): 类别标签
            
        返回:
            Tensor: 形状为[14,14,204]的标签张量
        """
        if len(boxes) == 0:
            # 如果没有边界框，返回全零张量
            return torch.zeros((self.s, self.s, (1 + 2 + 2*self.s + self.classes) * self.B * 2))
        
        # 计算类别相关张量 Q
        Q_tensor, Q_ct_tensor = self.generate_class_tensors(branch, boxes, labels)
        
        # 计算点存在性张量 P
        P_tensor, P_ct_tensor = self.generate_point_tensors(branch, boxes)
        
        # 计算连接关系张量 L
        Link_ct_list, Link_cr_list = self.generate_link_tensors(branch, boxes)
        
        # 计算相对位置张量 x
        x_tensor, x_ct_tensor = self.generate_position_tensors(branch, boxes)

        # 合并所有张量生成最终标签
        final_tensor = self.combine_tensors(
            Q_tensor, Q_ct_tensor, P_tensor, P_ct_tensor, 
            Link_ct_list, Link_cr_list, x_tensor, x_ct_tensor
        )
        
        return final_tensor

    def generate_point_tensors(self, branch, boxes):
        """
        生成点存在性张量
        
        参数:
            branch (int): 分支索引 (0-3)
            boxes (Tensor): 边界框坐标
            
        返回:
            tuple: (posi, posi_ct) 角点和中心点的存在性张量
        """
        posi = []
        posi_ct = []

        # 初始化张量
        p_tensor = torch.zeros((self.s, self.s, 1))
        p_ct_tensor = torch.zeros((self.s, self.s, 1))
        p_tensor1 = torch.zeros((self.s, self.s, 1))
        p_ct_tensor1 = torch.zeros((self.s, self.s, 1))
        
        for box in boxes:
            xmin, ymin, xmax, ymax = box
            
            # 防止坐标值过大
            if xmax >= 0.99:
                xmax -= 0.001
            if ymax >= 0.99:
                ymax -= 0.001
                
            # 根据分支索引设置不同角点
            if branch == 0:  # 左下角
                p_tensor[int(ymax * self.s), int(xmin * self.s)] = 1
                p_tensor1[int(ymax * self.s), int(xmin * self.s)] = 1
            elif branch == 1:  # 左上角
                p_tensor[int(ymin * self.s), int(xmin * self.s)] = 1
                p_tensor1[int(ymin * self.s), int(xmin * self.s)] = 1
            elif branch == 2:  # 右下角
                p_tensor[int(ymax * self.s), int(xmax * self.s)] = 1
                p_tensor1[int(ymax * self.s), int(xmax * self.s)] = 1
            elif branch == 3:  # 右上角
                p_tensor[int(ymin * self.s), int(xmax * self.s)] = 1
                p_tensor1[int(ymin * self.s), int(xmax * self.s)] = 1
                
            # 设置中心点
            center_y = int((ymin + ymax) / 2 * self.s)
            center_x = int((xmin + xmax) / 2 * self.s)
            p_ct_tensor[center_y, center_x] = 1
            p_ct_tensor1[center_y, center_x] = 1

        posi.append(p_tensor)
        posi.append(p_tensor1)
        posi_ct.append(p_ct_tensor)
        posi_ct.append(p_ct_tensor1)

        return posi, posi_ct

    def generate_position_tensors(self, branch, boxes):
        """
        生成相对位置张量
        
        参数:
            branch (int): 分支索引 (0-3)
            boxes (Tensor): 边界框坐标
            
        返回:
            tuple: (pos_list, pos_ct_list) 角点和中心点的相对位置张量
        """
        pos_list = []
        pos_ct_list = []

        # 初始化张量
        pos_tensor = torch.zeros((self.s, self.s, 2))
        pos_ct_tensor = torch.zeros((self.s, self.s, 2))
        pos_tensor1 = torch.zeros((self.s, self.s, 2))
        pos_ct_tensor1 = torch.zeros((self.s, self.s, 2))
        
        for box in boxes:
            xmin, ymin, xmax, ymax = box
            
            # 防止坐标值过大
            if xmax >= 0.99:
                xmax -= 0.001
            if ymax >= 0.99:
                ymax -= 0.001
                
            # 根据分支索引设置不同角点的相对位置
            if branch == 0:  # 左下角
                y_idx, x_idx = int(ymax * self.s), int(xmin * self.s)
                y_rel = ymax * self.s - y_idx
                x_rel = xmin * self.s - x_idx
                pos_tensor[y_idx, x_idx] = torch.tensor([y_rel, x_rel])
                pos_tensor1[y_idx, x_idx] = torch.tensor([y_rel, x_rel])
            elif branch == 1:  # 左上角
                y_idx, x_idx = int(ymin * self.s), int(xmin * self.s)
                y_rel = ymin * self.s - y_idx
                x_rel = xmin * self.s - x_idx
                pos_tensor[y_idx, x_idx] = torch.tensor([y_rel, x_rel])
                pos_tensor1[y_idx, x_idx] = torch.tensor([y_rel, x_rel])
            elif branch == 2:  # 右下角
                y_idx, x_idx = int(ymax * self.s), int(xmax * self.s)
                y_rel = ymax * self.s - y_idx
                x_rel = xmax * self.s - x_idx
                pos_tensor[y_idx, x_idx] = torch.tensor([y_rel, x_rel])
                pos_tensor1[y_idx, x_idx] = torch.tensor([y_rel, x_rel])
            elif branch == 3:  # 右上角
                y_idx, x_idx = int(ymin * self.s), int(xmax * self.s)
                y_rel = ymin * self.s - y_idx
                x_rel = xmax * self.s - x_idx
                pos_tensor[y_idx, x_idx] = torch.tensor([y_rel, x_rel])
                pos_tensor1[y_idx, x_idx] = torch.tensor([y_rel, x_rel])
                
            # 计算中心点相对位置
            ctx_p = (xmin * self.s + xmax * self.s) / 2
            cty_p = (ymin * self.s + ymax * self.s) / 2
            ctx = int(ctx_p)
            cty = int(cty_p)
            pos_ct_tensor[cty, ctx] = torch.tensor([ctx_p - ctx, cty_p - cty])
            pos_ct_tensor1[cty, ctx] = torch.tensor([ctx_p - ctx, cty_p - cty])

        pos_list.append(pos_tensor)
        pos_list.append(pos_tensor1)
        pos_ct_list.append(pos_ct_tensor)
        pos_ct_list.append(pos_ct_tensor1)
        
        return pos_list, pos_ct_list

    def generate_boxes_tensor(self, boxes):
        """
        处理边界框坐标，生成PLN格式的边界框张量
        
        参数:
            boxes (Tensor): 边界框坐标
            
        返回:
            Tensor: 形状为[N, 4, 2, 2]的边界框张量
        """
        if len(boxes) == 0:
            return torch.zeros((2, 4, 2, 2))
            
        boxes_list = []
        points = []
        
        for box in boxes:
            xmin, ymin, xmax, ymax = box
            
            # 防止坐标值过大
            if xmax >= 0.99:
                xmax -= 0.001
            if ymax >= 0.99:
                ymax -= 0.001
                
            # 计算中心坐标
            center_x = (xmin + xmax) / 2
            center_y = (ymin + ymax) / 2
            
            # 生成中心点与四个角点的连接关系
            # 中心点到左下角
            points.append(torch.tensor([
                (center_x * self.s, center_y * self.s), 
                (xmin * self.s, ymax * self.s)
            ]))
            
            # 中心点到左上角
            points.append(torch.tensor([
                (center_x * self.s, center_y * self.s), 
                (xmin * self.s, ymin * self.s)
            ]))
            
            # 中心点到右下角
            points.append(torch.tensor([
                (center_x * self.s, center_y * self.s), 
                (xmax * self.s, ymax * self.s)
            ]))
            
            # 中心点到右上角
            points.append(torch.tensor([
                (center_x * self.s, center_y * self.s), 
                (xmax * self.s, ymin * self.s)
            ]))
            
        points_tensor = torch.stack(points)
        boxes_list.append(points_tensor)
        
        stacked_tensor = torch.stack(boxes_list).squeeze(0)
        # 重塑为[N, 4, 2, 2]形状，表示N个边界框，每个有4个角点，每个点是2维坐标
        boxes_reshaped = stacked_tensor.view(int(stacked_tensor.shape[0] / 4), 4, 2, 2)
        
        return boxes_reshaped

    def generate_link_weights(self, center_pt, corner_pt):
        """
        生成中心点到角点的连接权重
        
        参数:
            center_pt (Tensor): 中心点坐标
            corner_pt (Tensor): 角点坐标
            
        返回:
            tuple: (L_ct, L_cr) 连接权重张量
        """
        # 初始化x和y方向的权重张量
        Lx_ct = torch.zeros(self.s)
        Ly_ct = torch.zeros(self.s)
        
        # 设置中心点到角点的连接权重
        Lx_ct[int(corner_pt[0])] = 1.0
        Ly_ct[int(corner_pt[1])] = 1.0
        
        # 连接x和y方向的权重
        L_ct = torch.cat((Lx_ct, Ly_ct))
        
        # 初始化角点到中心点的权重张量
        Lx_cr = torch.zeros(self.s)
        Ly_cr = torch.zeros(self.s)
        
        # 设置角点到中心点的连接权重
        Lx_cr[int(center_pt[0])] = 1.0
        Ly_cr[int(center_pt[1])] = 1.0
        
        # 连接x和y方向的权重
        L_cr = torch.cat((Lx_cr, Ly_cr))
        
        return L_ct, L_cr

    def generate_link_tensors(self, branch, boxes):
        """
        生成连接关系张量
        
        参数:
            branch (int): 分支索引 (0-3)
            boxes (Tensor): 边界框坐标
            
        返回:
            tuple: (Link_ct_list, Link_cr_list) 连接关系张量列表
        """
        Link_ct_list = []
        Link_cr_list = []
        
        # 初始化张量
        Link_tmp_cr = torch.zeros((self.s, self.s, 2 * self.s))
        Link_tmp_ct = torch.zeros((self.s, self.s, 2 * self.s))
        Link_tmp_cr1 = torch.zeros((self.s, self.s, 2 * self.s))
        Link_tmp_ct1 = torch.zeros((self.s, self.s, 2 * self.s))
        
        # 生成边界框张量   [N, 4, 2, 2]
        box_tensor = self.generate_boxes_tensor(boxes)
        
        # 提取指定分支的边界框
        box_branch = box_tensor[:, branch, ...]
        
        # 为每个边界框生成连接权重
        for obj_idx, obj_data in enumerate(box_branch):
            # 获取中心点和角点的连接权重
            L_ct, L_cr = self.generate_link_weights(obj_data[0], obj_data[1])
            
            # 在对应位置设置权重
            corner_y, corner_x = int(obj_data[1][1]), int(obj_data[1][0])
            center_y, center_x = int(obj_data[0][1]), int(obj_data[0][0])
            
            Link_tmp_cr[corner_y, corner_x] = L_cr
            Link_tmp_cr1[corner_y, corner_x] = L_cr
            Link_tmp_ct[center_y, center_x] = L_ct
            Link_tmp_ct1[center_y, center_x] = L_ct
        
        # 不应用softmax，直接使用原始值
        # 这样可以让模型学习到更清晰的链接关系
        Link_tmp_ct = Link_tmp_ct.clone()
        Link_tmp_ct1 = Link_tmp_ct1.clone()
        Link_tmp_cr = Link_tmp_cr.clone()
        Link_tmp_cr1 = Link_tmp_cr1.clone()
        
        Link_ct_list.append(Link_tmp_ct)
        Link_ct_list.append(Link_tmp_ct1)
        Link_cr_list.append(Link_tmp_cr)
        Link_cr_list.append(Link_tmp_cr1)
        
        return Link_ct_list, Link_cr_list

    def generate_class_tensors(self, branch, boxes, labels):
        """
        生成类别预测张量
        
        参数:
            branch (int): 分支索引 (0-3)
            boxes (Tensor): 边界框坐标
            labels (Tensor): 类别标签
            
        返回:
            tuple: (Q_list, Q_ct_list) 类别预测张量列表
        """
        Q_list = []
        Q_ct_list = []
        
        # 初始化张量
        Q_tensor = torch.zeros((self.s, self.s, self.classes))
        Q_ct_tensor = torch.zeros((self.s, self.s, self.classes))
        Q_tensor1 = torch.zeros((self.s, self.s, self.classes))
        Q_ct_tensor1 = torch.zeros((self.s, self.s, self.classes))
        
        # 为每个边界框设置类别预测
        for idx_ele, class_label in enumerate(labels):
            box = boxes[idx_ele]
            xmin, ymin, xmax, ymax = box
            
            # 防止坐标值过大
            if xmax >= 0.99:
                xmax -= 0.001
            if ymax >= 0.99:
                ymax -= 0.001
            
            # 根据分支索引设置不同角点的类别
            if branch == 0:  # 左下角
                Q_tensor[int(ymax * self.s), int(xmin * self.s), class_label] = self.infinite
                Q_tensor1[int(ymax * self.s), int(xmin * self.s), class_label] = self.infinite
            elif branch == 1:  # 左上角
                Q_tensor[int(ymin * self.s), int(xmin * self.s), class_label] = self.infinite
                Q_tensor1[int(ymin * self.s), int(xmin * self.s), class_label] = self.infinite
            elif branch == 2:  # 右下角
                Q_tensor[int(ymax * self.s), int(xmax * self.s), class_label] = self.infinite
                Q_tensor1[int(ymax * self.s), int(xmax * self.s), class_label] = self.infinite
            elif branch == 3:  # 右上角
                Q_tensor[int(ymin * self.s), int(xmax * self.s), class_label] = self.infinite
                Q_tensor1[int(ymin * self.s), int(xmax * self.s), class_label] = self.infinite
                
            # 设置中心点的类别
            center_y = int((ymin + ymax) / 2 * self.s)
            center_x = int((xmin + xmax) / 2 * self.s)
            Q_ct_tensor[center_y, center_x, class_label] = self.infinite
            Q_ct_tensor1[center_y, center_x, class_label] = self.infinite
        
        # 不应用softmax，直接使用原始值
        # 这样可以让模型学习到更清晰的类别区分
        Q_ct_tensor = Q_ct_tensor.clone()
        Q_tensor = Q_tensor.clone()
        Q_ct_tensor1 = Q_ct_tensor1.clone()
        Q_tensor1 = Q_tensor1.clone()
        
        Q_list.append(Q_tensor)
        Q_list.append(Q_tensor1)
        Q_ct_list.append(Q_ct_tensor)
        Q_ct_list.append(Q_ct_tensor1)
        
        return Q_list, Q_ct_list

    def combine_tensors(self, Q_list, Q_ct_list, posi_list, posi_ct_list, 
                     Link_ct_list, Link_cr_list, pos_list, pos_ct_list):
        """
        合并所有特征张量，生成最终的PLN格式标签
        
        参数:
            Q_list: 角点类别张量列表
            Q_ct_list: 中心点类别张量列表
            posi_list: 角点存在性张量列表
            posi_ct_list: 中心点存在性张量列表
            Link_ct_list: 中心点链接张量列表
            Link_cr_list: 角点链接张量列表
            pos_list: 角点相对位置张量列表
            pos_ct_list: 中心点相对位置张量列表
            
        返回:
            Tensor: 形状为[14,14,204]的最终标签张量
        """
        list_feature = []
        list_ct_feature = []
        
        # 组合每个边界框的特征
        for i in range(self.B):
            # 组合中心点特征: [存在性, 相对位置, 连接关系, 类别]
            center_features = torch.cat((
                posi_ct_list[i], 
                pos_ct_list[i],
                Link_ct_list[i],
                Q_ct_list[i]
            ), dim=-1)
            
            # 组合角点特征: [存在性, 相对位置, 连接关系, 类别]
            corner_features = torch.cat((
                posi_list[i], 
                pos_list[i], 
                Link_cr_list[i], 
                Q_list[i]
            ), dim=-1)
            
            list_feature.append(corner_features)
            list_ct_feature.append(center_features)
        
        # 合并所有角点特征
        feature_tensor = torch.cat(list_feature, dim=-1)
        # 合并所有中心点特征
        feature_ct_tensor = torch.cat(list_ct_feature, dim=-1)
        
        # 合并中心点和角点特征，得到最终标签张量
        return torch.cat((feature_ct_tensor, feature_tensor), dim=-1)


# 为了保持兼容性，创建一个别名
VocDataset = PLNDataset


if __name__ == "__main__":
    # 测试代码
    pln_dataset = PLNDataset(anno_file_type="train", img_size=448)
    dataloader = DataLoader(pln_dataset, shuffle=True, batch_size=1, num_workers=0)

    for i, (img, target) in enumerate(dataloader):
        if i == 0:
            print("Image shape:", img.shape)
            print("Target shape:", target.shape)
            break